"""Make table showing novels by year and gender."""
import argparse
import functools

import numpy as np
import pandas as pd

import datasets
import inference

# Command-line interface: a required output path, plus an optional switch
# that selects the by-decade table instead of the by-year table.
parser = argparse.ArgumentParser()
parser.add_argument(
    'output_filename',
    help='Output path for table.',
)
parser.add_argument(
    '--decade',
    action='store_true',
    help='show novels by decade instead of year.',
)


@functools.lru_cache()
def dataset_and_draws():
    """Build the combined dataset and fetch the posterior draws.

    Returns a ``(df, fit_extract)`` pair.  ``df`` is assembled by joining the
    frames returned by the ``datasets`` loaders onto the known title counts;
    ``fit_extract`` is whatever ``inference.sampling()`` returns.  The pair is
    memoized, so every caller receives the same objects.
    """
    # known title counts: the 1789-1799 series followed by the 1800-1836 series
    title_counts = pd.concat([
        datasets.raven_forster_1789_1799(),
        datasets.garside_schöwerling_title_counts(),
    ])
    frame = pd.DataFrame({'novels': title_counts})
    frame = frame.join(datasets.garside_schöwerling_title_counts_by_gender())

    # pad dataframe to 1919 with rows whose `novels` value is missing
    padding = pd.DataFrame({'novels': float('nan')}, index=range(1837, 1919 + 1))
    frame = pd.concat([frame, padding])

    def bassett_percentiles():
        """Return only the three percentile columns of the discounted bassett priors."""
        priors = datasets.bassett_at_the_circulating_library_priors_discounted()
        return priors[['bassett_25_percentile', 'bassett_50_percentile', 'bassett_75_percentile']]

    # remaining sources, loaded and joined one at a time in this order
    loaders = (
        datasets.publishers_circular,
        datasets.andrew_block,
        datasets.nineteenth_century_short_title_catalogue_loced,  # nstc
        bassett_percentiles,  # bassett priors
        datasets.casey_athenaeum_novels,  # ellen miller casey athenaeum counts
        datasets._population_british_isles,  # population estimates
    )
    for load in loaders:
        frame = frame.join(load())

    # add inferences
    return frame, inference.sampling()


def _dataset_years_all_subtask():
    """Return the posterior draws with the 1789-1799 draws prepended.

    For each simulated quantity the array from
    ``inference.posterior_raven_forster_1789_1799()`` is stacked column-wise
    in front of the array from ``dataset_and_draws()``.  The result is
    asserted to have shape (10_000, 131).  Used in totals calculations.
    """
    _, sampled = dataset_and_draws()
    combined = sampled.copy()  # work on a copy of the memoized extract
    earlier = inference.posterior_raven_forster_1789_1799()

    # prepend the raven forster columns to each simulated quantity
    for key in ('y_sim', 'y_unknown_sim', 'y_men_sim', 'y_women_sim'):
        stacked = np.hstack([earlier[key], combined[key]])
        assert stacked.shape == (10_000, 120 + 11)
        combined[key] = stacked
    return combined


def dataset_19th_all():
    """Calculate totals for 1800 to 1899 (inclusive).

    Sums the posterior draws over the century and summarizes each simulated
    quantity with a 90% interval, both as a count and as a percentage of the
    per-draw total (``y_sim``).  The resulting series is also printed.

    Returns:
        pandas.Series named ``'All'`` with entries ``{var}_p05``,
        ``{var}_p95``, ``{var}_pct_p05`` and ``{var}_pct_p95`` for each of
        ``y_sim``, ``y_unknown_sim``, ``y_men_sim`` and ``y_women_sim``.
    """
    # NOTE: code in the `*_all()` functions is very similar.
    fit_extract = _dataset_years_all_subtask()

    # the draws span 1789 to 1919 (131 columns); keep 1800 to 1899 by dropping
    # the first 11 columns (1789-1799) and the last 20 columns (1900-1919)
    century = slice(11, -20)

    # per-draw total over the century; the same for every variable, so compute it once
    total = fit_extract['y_sim'][:, century].sum(axis=1)

    # explicit dtype: an empty Series created without one triggers a pandas
    # warning and its default dtype differs between pandas versions
    series = pd.Series(name='All', dtype='float64')
    for var_name in ['y_sim', 'y_unknown_sim', 'y_men_sim', 'y_women_sim']:
        draws = fit_extract[var_name][:, century]  # shape (num_draws, num_years)
        assert draws.shape[1] == 100, draws.shape  # 1800 to 1899, 100 years
        draws = draws.sum(axis=1)  # sum across all years
        assert len(draws) >= 2000, len(draws)  # whatever the number of draws is
        # 90% posterior interval of the count and of its share of the total
        p05, p95 = pd.Series(draws).quantile([0.05, 0.95]).values
        pct_p05, pct_p95 = pd.Series(100 * draws / total).quantile([0.05, 0.95]).values
        series[f'{var_name}_p05'], series[f'{var_name}_p95'] = p05, p95
        series[f'{var_name}_pct_p05'], series[f'{var_name}_pct_p95'] = pct_p05, pct_p95

    print('totals for 1800-1899:', series)
    return series


def dataset_victorian_all():
    """Calculate totals for 1837 to 1901 (inclusive).

    Sums the posterior draws over the period and summarizes each simulated
    quantity with a 90% interval, both as a count and as a percentage of the
    per-draw total (``y_sim``).  The resulting series is also printed.

    Returns:
        pandas.Series named ``'All'`` with entries ``{var}_p05``,
        ``{var}_p95``, ``{var}_pct_p05`` and ``{var}_pct_p95`` for each of
        ``y_sim``, ``y_unknown_sim``, ``y_men_sim`` and ``y_women_sim``.
    """
    # NOTE: code in the `*_all()` functions is very similar.
    fit_extract = _dataset_years_all_subtask()

    # the draws span 1789 to 1919 (131 columns); keep 1837 to 1901 by dropping
    # the first 48 columns (1789-1836) and the last 18 columns (1902-1919)
    victorian = slice(48, -18)

    # per-draw total over the period; the same for every variable, so compute it once
    total = fit_extract['y_sim'][:, victorian].sum(axis=1)

    # explicit dtype: an empty Series created without one triggers a pandas
    # warning and its default dtype differs between pandas versions
    series = pd.Series(name='All', dtype='float64')
    for var_name in ['y_sim', 'y_unknown_sim', 'y_men_sim', 'y_women_sim']:
        draws = fit_extract[var_name][:, victorian]  # shape (num_draws, num_years)
        assert draws.shape[1] == 65, draws.shape  # 1837 to 1901, 65 years
        draws = draws.sum(axis=1)  # sum across all years
        assert len(draws) >= 2000, len(draws)  # whatever the number of draws is
        # 90% posterior interval of the count and of its share of the total
        p05, p95 = pd.Series(draws).quantile([0.05, 0.95]).values
        pct_p05, pct_p95 = pd.Series(100 * draws / total).quantile([0.05, 0.95]).values
        series[f'{var_name}_p05'], series[f'{var_name}_p95'] = p05, p95
        series[f'{var_name}_pct_p05'], series[f'{var_name}_pct_p95'] = pct_p05, pct_p95

    print('totals for 1837-1901 (victorian):', series)
    return series


def dataset_years_all():
    """Assemble series for final row of year rate table, 1789-1919.

    Calculates 90% intervals for counts and percentages, where percentages
    are taken with respect to the per-draw total (``y_sim``).  The resulting
    series is also printed.

    Returns:
        pandas.Series named ``'All'`` with entries ``{var}_p05``,
        ``{var}_p95``, ``{var}_pct_p05`` and ``{var}_pct_p95`` for each of
        ``y_sim``, ``y_unknown_sim``, ``y_men_sim`` and ``y_women_sim``.
    """
    # NOTE: code in the `*_all()` functions is very similar.
    fit_extract = _dataset_years_all_subtask()

    # per-draw total over all years; the same for every variable, so compute it once
    total = fit_extract['y_sim'].sum(axis=1)

    # explicit dtype: an empty Series created without one triggers a pandas
    # warning and its default dtype differs between pandas versions
    series = pd.Series(name='All', dtype='float64')
    for var_name in ['y_sim', 'y_unknown_sim', 'y_men_sim', 'y_women_sim']:
        draws = fit_extract[var_name]  # shape (num_draws, num_years)
        assert draws.shape[1] == 120 + 11, draws.shape  # 1789 to 1919, 131 years
        draws = draws.sum(axis=1)  # sum across all years
        assert len(draws) >= 2000, len(draws)  # whatever the number of draws is
        # 90% posterior interval of the count and of its share of the total
        p05, p95 = pd.Series(draws).quantile([0.05, 0.95]).values
        pct_p05, pct_p95 = pd.Series(100 * draws / total).quantile([0.05, 0.95]).values
        series[f'{var_name}_p05'], series[f'{var_name}_p95'] = p05, p95
        series[f'{var_name}_pct_p05'], series[f'{var_name}_pct_p95'] = pct_p05, pct_p95
    print('totals for 1789-1919:', series)
    return series


def dataset_decades_all():
    """Assembles series for final row of decade rate table, 1790-1919.

    Calculates 90% intervals for counts and percentages, where percentages
    are taken with respect to the per-draw total (``y_sim``).

    Supplements decades table, so covers 1790 to 1919 (inclusive).  The
    resulting series is also printed.

    Returns:
        pandas.Series named ``'All'`` with entries ``{var}_p05``,
        ``{var}_p95``, ``{var}_pct_p05`` and ``{var}_pct_p95`` for each of
        ``y_sim``, ``y_unknown_sim``, ``y_men_sim`` and ``y_women_sim``.
    """
    # NOTE: code in the `*_all()` functions is very similar.
    fit_extract = _dataset_years_all_subtask()

    # the draws span 1789 to 1919 (131 columns); drop 1789 (at index 0)
    from_1790 = slice(1, None)

    # per-draw total over 1790-1919; the same for every variable, so compute it once
    total = fit_extract['y_sim'][:, from_1790].sum(axis=1)

    # explicit dtype: an empty Series created without one triggers a pandas
    # warning and its default dtype differs between pandas versions
    series = pd.Series(name='All', dtype='float64')
    for var_name in ['y_sim', 'y_unknown_sim', 'y_men_sim', 'y_women_sim']:
        draws = fit_extract[var_name][:, from_1790]  # shape (num_draws, num_years)
        assert draws.shape[1] == 120 + 10, draws.shape  # 1790 to 1919, 130 years
        draws = draws.sum(axis=1)  # sum across all years
        assert len(draws) >= 2000, len(draws)  # whatever the number of draws is
        # 90% posterior interval of the count and of its share of the total
        p05, p95 = pd.Series(draws).quantile([0.05, 0.95]).values
        pct_p05, pct_p95 = pd.Series(100 * draws / total).quantile([0.05, 0.95]).values
        series[f'{var_name}_p05'], series[f'{var_name}_p95'] = p05, p95
        series[f'{var_name}_pct_p05'], series[f'{var_name}_pct_p95'] = pct_p05, pct_p95

    print('totals for 1790-1919 (for decades table):', series)
    return series


def dataset_decades():
    """Assembles dataframe for decade rate table.

    Covers 1790 to 1919 (inclusive), one row per decade.

    Calculates 90% intervals for counts, and 90% and 50% intervals for
    percentages of the decade total.

    Returns:
        pandas.DataFrame indexed by the first year of each decade (1790,
        1800, ..., 1910).  Count columns are ``{var}_p05`` and ``{var}_p95``
        for ``y_sim``, ``y_unknown_sim``, ``y_men_sim`` and ``y_women_sim``;
        percentage columns are ``{var}_pct_p05``, ``{var}_pct_p95``,
        ``{var}_pct_p25`` and ``{var}_pct_p75`` for the three gender
        categories.
    """
    count_names = ['y_sim', 'y_unknown_sim', 'y_men_sim', 'y_women_sim']
    pct_names = ['y_unknown_sim_pct', 'y_men_sim_pct', 'y_women_sim_pct']

    # draws for 1789 to 1919 with the raven forster years already prepended
    fit_extract = _dataset_years_all_subtask()
    num_draws = fit_extract['y_sim'].shape[0]

    # decade index for every year from 1790 to 1919: 0 for the 1790s, ..., 12 for the 1910s
    decade_bins = np.digitize(range(1790, 1919 + 1), range(1800, 1919 + 1, 10))
    assert max(decade_bins) == 12
    num_decades = 12 + 1

    # sum the yearly draws within each decade
    fit_extract_decades = {}
    for var_name in count_names:
        # first need to drop 1789 since we are only covering 1790 to 1919
        yearly = fit_extract[var_name][:, 1:]
        assert yearly.shape == (num_draws, 130), var_name
        totals = np.zeros((num_draws, num_decades))
        for decade_index in range(num_decades):
            totals[:, decade_index] = yearly[:, decade_bins == decade_index].sum(axis=1)
        # every quantity must be positive in every decade; the percentage
        # checks below rely on this as well
        assert totals.min() > 0, var_name
        fit_extract_decades[var_name] = totals

    # the number of total novels published between 1790 and 1799 is known and
    # not modeled, so there's no posterior predictive check to do.
    # Here every "sample" for the total number of novels is the known number, 701 novels
    np.testing.assert_array_equal(fit_extract_decades['y_sim'][:, 0], 701)  # 701 novels, known to be published 1790 to 1799

    # add percentages of the decade total
    y_sim_decades = fit_extract_decades['y_sim']
    assert y_sim_decades.shape == (num_draws, num_decades)
    for pct_name in pct_names:
        pct = 100 * fit_extract_decades[pct_name.replace('_pct', '')] / y_sim_decades
        assert pct.shape == (num_draws, num_decades)
        assert pct.min() > 0
        assert pct.max() < 100
        fit_extract_decades[pct_name] = pct

    # construct data frame
    df = pd.DataFrame(index=range(1790, 1910 + 1, 10))
    assert len(df) == num_decades

    # need percentages for everything, but i'll ignore 1800s, 1810s, 1820s
    # so there's just a single special case with the percentages, 1790s, where
    # we know the total (so y_sim is a constant)

    # populate the DataFrame with counts
    for var_name in count_names:
        quantiles = pd.DataFrame(fit_extract_decades[var_name]).quantile([0.05, 0.95], axis=0).values
        np.testing.assert_equal(quantiles.shape, (2, num_decades))
        p05, p95 = quantiles  # one row per quantile, one entry per decade
        df[f'{var_name}_p05'], df[f'{var_name}_p95'] = p05, p95

    # populate the DataFrame with percentages
    for pct_name in pct_names:
        quantiles = pd.DataFrame(fit_extract_decades[pct_name]).quantile([0.05, 0.25, 0.75, 0.95], axis=0).values
        np.testing.assert_equal(quantiles.shape, (4, num_decades))
        p05, p25, p75, p95 = quantiles
        df[f'{pct_name}_p05'], df[f'{pct_name}_p95'] = p05, p95
        df[f'{pct_name}_p25'], df[f'{pct_name}_p75'] = p25, p75

    return df


def dataset_years():
    """Assembles dataframe for yearly rate table.

    Starts from a copy of the frame returned by ``dataset_and_draws()`` and
    adds pointwise posterior quantiles (5%, 25%, 50%, 75%, 95%) for the rows
    1789 to 1919.

    Returns:
        pandas.DataFrame with the added columns ``{var}_pXX`` for ``y_sim``,
        ``y_unknown_sim``, ``y_men_sim`` and ``y_women_sim``;
        ``{var}_pct_pXX`` for the three gender categories as a percentage of
        ``y_sim``; and ``men_of_known_pct_sim_pXX`` for men as a percentage
        of men plus women.
    """
    cached_df, _ = dataset_and_draws()
    # dataset_and_draws() is memoized: copy so the columns added below do not
    # leak into the cached frame that other callers receive
    df = cached_df.copy()

    # draws for 1789 to 1919 with the raven forster years already prepended
    fit_extract = _dataset_years_all_subtask()

    quantile_levels = [0.05, 0.25, 0.5, 0.75, 0.95]
    quantile_labels = ['p05', 'p25', 'p50', 'p75', 'p95']

    def add_quantile_columns(prefix, draws):
        """Store the per-year quantiles of `draws` in columns `{prefix}_pXX`."""
        # one row per quantile level, one entry per year
        per_year = pd.DataFrame(draws).quantile(quantile_levels, axis=0).values
        for label, values in zip(quantile_labels, per_year):
            df.loc[1789:1919, f'{prefix}_{label}'] = values

    # pointwise posterior intervals for counts, goes from 1789 to 1919
    for var_name in ['y_sim', 'y_unknown_sim', 'y_men_sim', 'y_women_sim']:
        add_quantile_columns(var_name, fit_extract[var_name])

    # pointwise posterior intervals, percentages of the yearly total
    for var_name in ['y_unknown_sim', 'y_men_sim', 'y_women_sim']:
        add_quantile_columns(f'{var_name}_pct', 100 * fit_extract[var_name] / fit_extract['y_sim'])

    # ratio of men to known (men plus women)
    known = fit_extract['y_men_sim'] + fit_extract['y_women_sim']
    add_quantile_columns('men_of_known_pct_sim', 100 * fit_extract['y_men_sim'] / known)

    assert int(df.loc[1789, 'y_unknown_sim_p05']) > 1
    return df


def make_table_by_decade(filename):
    """Make table showing number of new titles published by decade and gender.

    Covers 1790 to 1919 (inclusive), with a final 'All' row taken from
    ``dataset_decades_all()``.  The table is written to `filename` as LaTeX.

    Args:
        filename: path of the file to write; an existing file is overwritten.
    """
    df = dataset_decades()

    # total novels are known for 1789 to 1836 and gender is known for 1800 to
    # 1829, so the 1800s, 1810s and 1820s should not be intervals: set both
    # ends of the interval to the known count.
    # (1790s total is already not an interval)
    known_counts = {
        'y_sim': {1800: 778, 1810: 669, 1820: 829},
        'y_women_sim': {1800: 366, 1810: 346, 1820: 289},
        'y_men_sim': {1800: 298, 1810: 197, 1820: 426},
        'y_unknown_sim': {1800: 114, 1810: 126, 1820: 114},
    }
    for col, count_by_decade in known_counts.items():
        for decade, count in count_by_decade.items():
            df.loc[decade, f'{col}_p05'] = df.loc[decade, f'{col}_p95'] = count

    def make_interval(row, force_int=False):
        """Format a (low, high) row as 'low-high', or 'low' when both are equal."""
        if pd.isnull(row).any():
            return ''
        if force_int:
            row = row.astype(int)
        low, high = row
        return f'{round(low, 2):,}' if low == high else f'{round(low, 2):,}-{round(high, 2):,}'

    def make_interval_with_pct(row, force_int=False):
        """Format a (low, high, low_pct, high_pct) row as 'low-high (low_pct-high_pct%)'."""
        if pd.isnull(row).any():
            return ''
        if force_int:
            row = row.astype(int)
        low, high, low_pct, high_pct = row
        return f'{round(low, 2):,} ({round(low_pct)}%)' if low == high else f'{round(low, 2):,}-{round(high, 2):,} ({round(low_pct)}-{round(high_pct)}%)'

    df.index = [f'{year}-{int(year) + 9}' for year in df.index]

    # add the final row; pd.concat replaces DataFrame.append, which was
    # deprecated in pandas 1.4 and removed in pandas 2.0
    all_row = dataset_decades_all().to_frame().T.infer_objects()
    df = pd.concat([df, all_row])

    # assemble the formatted table
    df_formatted = pd.DataFrame(index=df.index)
    # add columns for each gender category
    for col in ['y_men_sim', 'y_women_sim', 'y_unknown_sim']:
        cols_oi = [f'{col}_p05', f'{col}_p95', f'{col}_pct_p05', f'{col}_pct_p95']
        df_formatted[col] = df[cols_oi].apply(
            lambda row: make_interval_with_pct(row, force_int=True), axis=1)

    # add columns for total count, no percentage
    for col in ['y_sim']:
        cols_oi = [f'{col}_p05', f'{col}_p95']
        df_formatted[col] = df[cols_oi].apply(
            lambda row: make_interval(row, force_int=True), axis=1)

    df_formatted = df_formatted.rename(columns={
        'y_sim': 'N',
        'y_men_sim': 'Men-authored',
        'y_women_sim': 'Women-authored',
        'y_unknown_sim': 'Unknown',
    })

    with open(filename, 'w') as fh:
        fh.write(df_formatted.to_latex(index=True))


def make_table_by_year(filename):
    """Make table showing number of new titles published by year and gender.

    Covers the entire period, 1789 to 1919 (inclusive), with a final 'All'
    row taken from ``dataset_years_all()``.  The table is written to
    `filename` as a LaTeX longtable with the caption placed inside the
    environment.

    Args:
        filename: path of the file to write; an existing file is overwritten.
    """
    df = dataset_years()

    columns_oi = ['novels', 'novels_male', 'novels_female', 'novels_unknown',
                  'y_sim_p05', 'y_sim_p95', 'y_unknown_sim_p05',
                  'y_unknown_sim_p95', 'y_men_sim_p05', 'y_men_sim_p95',
                  'y_women_sim_p05', 'y_women_sim_p95']
    df = df[columns_oi]

    # add the final row; pd.concat replaces DataFrame.append, which was
    # deprecated in pandas 1.4 and removed in pandas 2.0
    all_row = dataset_years_all().to_frame().T.infer_objects()
    df = pd.concat([df, all_row])

    df = df.rename(columns={
        'y_sim_p05': 'Novels, All p05',
        'y_sim_p95': 'Novels, All p95',
        'y_unknown_sim_p05': 'Novels, Unknown p05',
        'y_unknown_sim_p95': 'Novels, Unknown p95',
        'y_men_sim_p05': 'Novels, Men-authored p05',
        'y_men_sim_p95': 'Novels, Men-authored p95',
        'y_women_sim_p05': 'Novels, Women-authored p05',
        'y_women_sim_p95': 'Novels, Women-authored p95',
        'novels': 'Novels, All (RFGS)',
        'novels_male': 'Novels, Men-authored (RFGS)',
        'novels_female': 'Novels, Women-authored (RFGS)',
        'novels_unknown': 'Novels, Unknown (RFGS)',
    })

    # keep only the renamed columns (this drops the extra columns brought in by the 'All' row)
    desired_column_order = [
        'Novels, All p05', 'Novels, All p95', 'Novels, Unknown p05', 'Novels, Unknown p95',
        'Novels, Men-authored p05', 'Novels, Men-authored p95', 'Novels, Women-authored p05', 'Novels, Women-authored p95',
        'Novels, All (RFGS)', 'Novels, Men-authored (RFGS)', 'Novels, Women-authored (RFGS)', 'Novels, Unknown (RFGS)',
    ]
    df = df.reindex(columns=desired_column_order)

    def make_interval(row, force_int=False):
        """Format a (low, high) row as 'low-high', or 'low' when both are equal."""
        if pd.isnull(row).any():
            return ''
        if force_int:
            row = row.astype(int)
        low, high = row
        return f'{round(low, 2):,}' if low == high else f'{round(low, 2):,}-{round(high, 2):,}'

    # collapse each pair of p05/p95 columns into one formatted interval column
    model_columns = ['Novels, All', 'Novels, Women-authored', 'Novels, Men-authored', 'Novels, Unknown']
    for column in model_columns:
        df[column] = df[[f'{column} p05', f'{column} p95']].apply(
            lambda row: make_interval(row, force_int=True), axis=1)
        del df[f'{column} p05']
        del df[f'{column} p95']
    # format the remaining (RFGS) count columns; missing values become empty strings
    for column in df.columns.drop(model_columns):
        df[column] = df[column].apply(lambda el: '' if pd.isnull(el) else '{:,}'.format(int(el)))

    # reorder one last time
    desired_column_order = [
        'Novels, Men-authored', 'Novels, Women-authored', 'Novels, Unknown', 'Novels, All',
        'Novels, Men-authored (RFGS)', 'Novels, Women-authored (RFGS)', 'Novels, Unknown (RFGS)', 'Novels, All (RFGS)',
    ]
    df = df.reindex(columns=desired_column_order)

    # overwrite intervals with the known RFGS counts where those exist
    # indexes are now objects because there is a string 'All' row
    years_rfgs_gender = list(range(1800, 1829 + 1))
    df.loc[years_rfgs_gender, ['Novels, Men-authored', 'Novels, Women-authored', 'Novels, Unknown', 'Novels, All']] = \
        df.loc[years_rfgs_gender, ['Novels, Men-authored (RFGS)', 'Novels, Women-authored (RFGS)', 'Novels, Unknown (RFGS)', 'Novels, All (RFGS)']].values
    years_rfgs_no_gender = list(range(1830, 1836 + 1))
    df.loc[years_rfgs_no_gender, 'Novels, All'] = df.loc[years_rfgs_no_gender, 'Novels, All (RFGS)'].values

    # remove the RFGS columns
    columns_oi = ['Novels, Men-authored', 'Novels, Women-authored', 'Novels, Unknown', 'Novels, All']
    df = df[columns_oi]

    # longtable caption must go inside the longtable environment (with no
    # enclosing `table` environment), so it is inserted just before the
    # final line of the generated LaTeX
    longtable_latex = df.to_latex(index=True, longtable=True, escape=False)
    longtable_caption = r"""\caption{\textbf{New novels published between 1789 and 1919.} Intervals show 90\% credible
        intervals. Percentages are calculated with respect to table rows. Where intervals do not appear (1789--1836),
        counts shown are from RFGS. RFGS provide counts of new novels by author gender for 1800-1829 and total new novels for all years between 1789 and 1836.\label{tbl:novels-by-year}}"""
    longtable_latex_lines = longtable_latex.splitlines()
    with open(filename, 'w') as fh:
        fh.write('\n'.join(longtable_latex_lines[:-1]))
        fh.write('\n\n' + longtable_caption + '\n\n')
        fh.write(longtable_latex_lines[-1])

if __name__ == '__main__':
    args = parser.parse_args()
    # choose the table builder from the --decade switch, then write to the requested path
    write_table = make_table_by_decade if args.decade else make_table_by_year
    write_table(args.output_filename)
